import torch
import torch.nn as nn

from models.mish import Mish


class DNN(nn.Module):
    """Fully connected network: hidden blocks of Linear -> Dropout -> Mish,
    followed by a final Linear layer producing ``num_class`` outputs.

    With the default arguments the architecture is exactly the original one:
    ``d -> 1024 -> 512 -> 256 -> num_class`` with dropout 0.3, 0.3, 0.2.

    Args:
        d: Number of input features (size of the last dimension of ``x``).
        num_class: Number of output units (presumably one per class).
        hidden_dims: Widths of the hidden Linear layers, in order.
        dropouts: Dropout probability used in each hidden block; must have
            the same length as ``hidden_dims``.

    Raises:
        ValueError: If ``hidden_dims`` and ``dropouts`` differ in length.
    """

    def __init__(self, d, num_class, *, hidden_dims=(1024, 512, 256),
                 dropouts=(0.3, 0.3, 0.2)):
        super().__init__()
        hidden_dims = tuple(hidden_dims)
        dropouts = tuple(dropouts)
        if len(hidden_dims) != len(dropouts):
            raise ValueError(
                f"hidden_dims ({len(hidden_dims)}) and dropouts "
                f"({len(dropouts)}) must have the same length"
            )

        layers = []
        in_features = d
        for width, p in zip(hidden_dims, dropouts):
            # NOTE(review): dropout is applied *before* the activation here
            # (Linear -> Dropout -> Mish). Kept as-is so trained models behave
            # identically; the more common order is Linear -> Mish -> Dropout.
            layers += [nn.Linear(in_features, width), nn.Dropout(p), Mish()]
            in_features = width
        layers.append(nn.Linear(in_features, num_class))

        # Module order matches the original hand-written Sequential, so the
        # state_dict keys (fc.0.*, fc.3.*, fc.6.*, fc.9.*) are unchanged and
        # existing checkpoints still load.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw network outputs for ``x`` (no softmax is applied)."""
        return self.fc(x)
